!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastbook import *
from fastai.vision.widgets import *
     |████████████████████████████████| 720 kB 4.2 MB/s 
     |████████████████████████████████| 46 kB 4.2 MB/s 
     |████████████████████████████████| 1.2 MB 42.6 MB/s 
     |████████████████████████████████| 189 kB 36.3 MB/s 
     |████████████████████████████████| 56 kB 4.6 MB/s 
     |████████████████████████████████| 51 kB 327 kB/s 
Mounted at /content/gdrive
# Google Drive location for the dataset (Colab mount).
# NOTE(review): this value is immediately shadowed by `path = Path('tanks')`
# in the download cell below — confirm which location is actually intended.
path = Path('/content/gdrive/MyDrive/tanks')

Download images for training

# Class names to collect; 'water' acts as a negative / "other" class.
tank_types = ('merkava mk4', 'M1 Abrams', 'water')
path = Path('tanks')

# Download up to 150 DuckDuckGo image results per class into tanks/<class>/.
# Only runs on the first execution: skipped entirely if the folder exists.
if not path.exists():
    path.mkdir()
    for o in tank_types:
        dest = path / o
        dest.mkdir(exist_ok=True)
        results = search_images_ddg(f'{o} tank', max_images=150)
        download_images(dest, urls=results)

# Remove any downloads that cannot actually be opened as images.
fns = get_image_files(path)
failed = verify_images(fns)
failed.map(Path.unlink);

Preparing the data for the model

# Single-label DataBlock: each image gets exactly one class, taken from the
# name of the folder it lives in (parent_label).
tanks = DataBlock(
    # images in, one category out
    blocks=(ImageBlock, CategoryBlock), 
    get_items=get_image_files, 
    # hold out 20% for validation, fixed seed so the split is reproducible
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    # random crop covering at least 50% of the image, resized to 224x224
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    # standard fastai batch augmentations (flip, rotate, zoom, lighting)
    batch_tfms=aug_transforms())
dls = tanks.dataloaders(path)
# Sanity-check a few validation images together with their labels.
dls.valid.show_batch(max_n=4, nrows=1)

Training the model and cleaning the data

# Transfer-learn a ResNet-18 classifier on the tank DataLoaders.
# NOTE(review): `cnn_learner` was renamed `vision_learner` in fastai >= 2.7;
# the alias still works but may warn — confirm the fastai version in use.
learn = cnn_learner(dls, resnet18, metrics=error_rate)
# One frozen epoch for the new head, then 4 epochs fine-tuning the whole net.
learn.fine_tune(4)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
epoch train_loss valid_loss error_rate time
0 1.437875 0.810209 0.333333 00:54
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
epoch train_loss valid_loss error_rate time
0 0.536968 0.407019 0.142857 00:24
1 0.393917 0.193613 0.066667 00:25
2 0.297317 0.139224 0.057143 00:24
3 0.249617 0.127603 0.047619 00:24
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
# Confusion matrix on the validation set to see which classes get mixed up.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
def plot_top_losses_fix(interp, k, largest=True, **kwargs):
    """Plot the `k` highest-loss validation items for `interp`.

    Local replacement for `Interpretation.plot_top_losses`, adapted from the
    fastai implementation (the commented-out `L(self.preds)` line below shows
    the upstream call it replaces).

    Parameters
    ----------
    interp : ClassificationInterpretation
        Interpretation object holding inputs, targets, decoded preds and preds.
    k : int
        Number of items to show.
    largest : bool
        If True show the largest losses, otherwise the smallest.
    **kwargs
        Forwarded to `plot_top_losses` (e.g. nrows).
    """
    losses, idx = interp.top_losses(k, largest)
    # Normalise inputs to a tuple WITHOUT mutating the interp object
    # (the original assigned back to interp.inputs as a hidden side effect).
    inputs = interp.inputs if isinstance(interp.inputs, tuple) else (interp.inputs,)
    if isinstance(inputs[0], Tensor):
        inps = tuple(o[idx] for o in inputs)
    else:
        inps = interp.dl.create_batch(
            interp.dl.before_batch([tuple(o[i] for o in inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (interp.targs if is_listy(interp.targs) else (interp.targs,)))
    x, y, its = interp.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (interp.decoded if is_listy(interp.decoded) else (interp.decoded,)))
    x1, y1, outs = interp.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        # interp.preds[idx] replaces the broken L(self.preds).itemgot(idx)
        # from the upstream implementation:
        # plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), L(self.preds).itemgot(idx), losses, **kwargs)
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)),
                        interp.preds[idx], losses, **kwargs)
    # `its is None` means the batch knows how to show itself as a whole; that
    # branch is not needed for image classification and is deliberately omitted.
# Show the 10 highest-loss validation images using the local workaround above.
plot_top_losses_fix(interp, 10, nrows=10)
# Interactive widget for relabelling/deleting noisy downloaded images.
cleaner = ImageClassifierCleaner(learn)
cleaner
 
# Apply the relabellings chosen in the ImageClassifierCleaner widget: move
# each re-categorised image into its new class folder.  The filename is
# prefixed with its OLD class-folder name so it cannot clash with an
# identically named file already present in the destination folder.
# NOTE(review): items marked <Delete> in the widget are NOT removed here; if
# deletion is wanted, also run:
#   for idx in cleaner.delete(): cleaner.fns[idx].unlink()
for idx, cat in cleaner.change():
    f = cleaner.fns[idx]
    # Single move with the collision-safe name, instead of the original
    # rename-then-move two-step (which computed the new path twice).
    shutil.move(str(f), str(path / cat / (f.parent.name + f.stem + f.suffix)))

POC with user interface

Using the Model for Inference

# Serialise the trained single-label learner (architecture + weights +
# DataLoaders recipe) to ./export.pkl for later inference.
learn.export()
path = Path()   # current working directory
path.ls(file_exts='.pkl')
(#1) [Path('export.pkl')]
# Reload the exported learner for inference.
learn_inf = load_learner(path/'export.pkl')

Building a multi-label classifier on the same data, with a custom loss function

# Experiment: train with a multi-label loss on the current dataloaders.
# NOTE(review): F.binary_cross_entropy_with_logits expects float multi-hot
# targets, but `dls` at this point was built with CategoryBlock (integer
# class indices) — confirm this cell is meant to run only after the
# MultiCategoryBlock dataloaders are created below.
# NOTE(review): a plain loss function also lacks fastai's activation/decodes
# hooks, so `lr.predict` will return raw logits rather than probabilities.
lr = cnn_learner(dls, resnet18, loss_func=F.binary_cross_entropy_with_logits, metrics=partial(accuracy_multi, thresh=0.2) )
lr.fine_tune(4)

Simple UI demo

 
# --- Minimal ipywidgets UI for single-image inference ---
path = Path()
# NOTE(review): this expects 'multi.pkl' in the cwd, but the export cells in
# this notebook write 'export.pkl' and 'multi' — confirm the filename matches.
learn_inf = load_learner(path/'multi.pkl')
btn_upload = widgets.FileUpload()   # file picker for the image to classify
out_pl = widgets.Output()           # area where the uploaded thumbnail is shown
lbl_pred = widgets.Label()          # prediction text
btn_run = widgets.Button(description='Classify')
def on_click_classify(change):
    # Read the most recently uploaded file.
    # NOTE(review): `.data` is fastai's patch of ipywidgets.FileUpload; on
    # newer ipywidgets the payload lives in `.value` — confirm versions.
    img = PILImage.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(128,128))
    pred,pred_idx,probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'

btn_run.on_click(on_click_classify)
VBox([widgets.Label('Select your tank!'), 
      btn_upload, btn_run, out_pl, lbl_pred])

Additional validation

learnerm.predict('/content/126.JPG')
((#1) ['water'],
 TensorBase([False, False, False,  True]),
 TensorBase([0.4033, 0.0402, 0.1413, 0.9869]))
multi_learner2.predict('/content/DUDU.jpg')
((#0) [],
 TensorBase([-2.2279, -0.8042,  0.6245]),
 TensorBase([-2.2279, -0.8042,  0.6245]))
a = torch.tensor([-2.2279, -0.8042,  0.6245])
a.sigmoid()
tensor([0.0973, 0.3091, 0.6512])
multi_learner2.predict('/content/TetroWtank.jpg')
((#0) [],
 TensorBase([-5.6009, -2.1743,  4.8451]),
 TensorBase([-5.6009, -2.1743,  4.8451]))
a = torch.tensor([-5.6009, -2.1743,  4.8451])
a.sigmoid()
tensor([0.0037, 0.1021, 0.9922])
create_cnn_model?
cnn_learner??
multi_learner2.predict??
multi_learner.predict('/content/DUDU.jpg')
((#2) ['merkava mk4','water'],
 TensorBase([False,  True,  True]),
 TensorBase([0.1738, 0.9022, 0.5532]))
multi_learner.predict('/content/manytanks.jpg')
multi_learner.predict('/content/merk_abrms2.jpeg')
((#2) ['M1 Abrams','merkava mk4'],
 TensorBase([ True,  True, False]),
 TensorBase([0.6895, 0.9527, 0.0021]))
multi_learner.predict('/content/merkav_abrams.jpg')
((#1) ['M1 Abrams'],
 TensorBase([ True, False, False]),
 TensorBase([0.9377, 0.4514, 0.0021]))
multi_learner.predict('/content/mek_abs_3.jpg')
((#1) ['M1 Abrams'],
 TensorBase([ True, False, False]),
 TensorBase([1.0000, 0.0072, 0.0388]))
multi_learner.predict('/content/TetroWtank.jpg')
((#1) ['water'],
 TensorBase([False, False,  True]),
 TensorBase([0.0694, 0.1050, 0.9140]))
a, b, c = lr.predict('/content/merk_abrms2.jpeg')
c.sigmoid()
TensorBase([0.2893, 0.9868, 0.0014])

Try the multi-label model on the same data

def parentlabel(x):
    """Label *x* by its parent directory name, wrapped in a one-element list
    (MultiCategoryBlock expects each item's labels as a collection)."""
    folder = x.parent
    return [folder.name]
# Multi-label DataBlock: identical pipeline, but targets go through
# MultiCategoryBlock, which one-hot encodes the (single-element) label lists
# produced by parentlabel.
tanks = DataBlock(
    blocks=(ImageBlock, MultiCategoryBlock), 
    #MultiCategoryBlock(add_na=True)
    get_items=get_image_files, 
    # same split seed as the single-label block, so the folds are comparable
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parentlabel,
    item_tfms=RandomResizedCrop(224, min_scale=0.5),
    batch_tfms=aug_transforms())
dls = tanks.dataloaders(path)
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1051: UserWarning: torch.solve is deprecated in favor of torch.linalg.solveand will be removed in a future PyTorch release.
torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:766.)
  ret = func(*args, **kwargs)
dls.valid.show_batch(nrows=1, ncols=6)
# Multi-label learner; fastai picks BCEWithLogitsLossFlat by default for
# MultiCategoryBlock targets.
multi_learner = cnn_learner(dls, resnet18, metrics=accuracy_multi) # accuracy_multi defaults: thresh=0.5, sigmoid=True; try partial(accuracy_multi, thresh=0.95)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
# One frozen epoch for the head, then 4 epochs fine-tuning the whole net.
multi_learner.fine_tune(4)
epoch train_loss valid_loss accuracy_multi time
0 0.932164 0.447739 0.820513 02:26
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
epoch train_loss valid_loss accuracy_multi time
0 0.448215 0.294454 0.868590 00:24
1 0.357718 0.199940 0.916667 00:25
2 0.282826 0.181428 0.935897 00:24
3 0.237339 0.176398 0.935897 00:24
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "

Fixing the top-losses plot

def plot_top_losses_fix(interp, k, largest=True, **kwargs):
    """Plot the `k` highest-loss validation items for `interp`.

    Local replacement for `Interpretation.plot_top_losses`, adapted from the
    fastai implementation (the commented-out `L(self.preds)` line below shows
    the upstream call it replaces).

    Parameters
    ----------
    interp : ClassificationInterpretation
        Interpretation object holding inputs, targets, decoded preds and preds.
    k : int
        Number of items to show.
    largest : bool
        If True show the largest losses, otherwise the smallest.
    **kwargs
        Forwarded to `plot_top_losses` (e.g. nrows).
    """
    losses, idx = interp.top_losses(k, largest)
    # Normalise inputs to a tuple WITHOUT mutating the interp object
    # (the original assigned back to interp.inputs as a hidden side effect).
    inputs = interp.inputs if isinstance(interp.inputs, tuple) else (interp.inputs,)
    if isinstance(inputs[0], Tensor):
        inps = tuple(o[idx] for o in inputs)
    else:
        inps = interp.dl.create_batch(
            interp.dl.before_batch([tuple(o[i] for o in inputs) for i in idx]))
    b = inps + tuple(o[idx] for o in (interp.targs if is_listy(interp.targs) else (interp.targs,)))
    x, y, its = interp.dl._pre_show_batch(b, max_n=k)
    b_out = inps + tuple(o[idx] for o in (interp.decoded if is_listy(interp.decoded) else (interp.decoded,)))
    x1, y1, outs = interp.dl._pre_show_batch(b_out, max_n=k)
    if its is not None:
        # interp.preds[idx] replaces the broken L(self.preds).itemgot(idx)
        # from the upstream implementation:
        # plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)), L(self.preds).itemgot(idx), losses, **kwargs)
        plot_top_losses(x, y, its, outs.itemgot(slice(len(inps), None)),
                        interp.preds[idx], losses, **kwargs)
    # `its is None` means the batch knows how to show itself as a whole; that
    # branch is not needed for image classification and is deliberately omitted.
multi_learner.dls.vocab
['M1 Abrams', 'merkava mk4', 'water']
intrpt = ClassificationInterpretation.from_learner(multi_learner)
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
# Inspect the 10 highest-loss multi-label predictions.
plot_top_losses_fix(intrpt, 10, nrows=10)
# Re-evaluate with a lower decision threshold for accuracy_multi.
# NOTE(review): fastai normally expects `learner.metrics` to be a list;
# assigning a bare partial appears to work via the metrics setter — confirm.
multi_learner.metrics = partial(accuracy_multi, thresh=0.2)
multi_learner.validate()
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
(#2) [0.17249105870723724,0.9070513248443604]
preds, targs = multi_learner.get_preds()
preds.shape, targs.shape
(torch.Size([104, 3]), torch.Size([104, 3]))
preds[0], targs[0]
(TensorBase([0.9940, 0.0058, 0.0028]), TensorMultiCategory([1., 0., 0.]))
accuracy_multi(preds, targs, thresh=0.5, sigmoid=False)
TensorBase(0.9423)
th = [0.1, 0.3, 0.5, 0.7, 0.9]
[accuracy_multi(preds, targs, thresh=x, sigmoid=False) for x in th]
[TensorBase(0.8750),
 TensorBase(0.9167),
 TensorBase(0.9423),
 TensorBase(0.9487),
 TensorBase(0.9487)]
# Sweep the decision threshold from 0.05 to 0.95 and plot validation accuracy
# at each value. `preds` are already probabilities (get_preds applies the
# loss activation), hence sigmoid=False.
xs = torch.linspace(0.05,0.95,29)
accs = [accuracy_multi(preds, targs, thresh=i, sigmoid=False) for i in xs]
plt.plot(xs,accs);
multi_learner.loss_func.thresh=0.2
multi_learner.loss_func.activation
<bound method BCEWithLogitsLossFlat.activation of FlattenedLoss of BCEWithLogitsLoss()>
multi_learner.loss_func
FlattenedLoss of BCEWithLogitsLoss()
multi_learner.loss_func.decodes
<bound method BCEWithLogitsLossFlat.decodes of FlattenedLoss of BCEWithLogitsLoss()>
# Export the multi-label learner. Use the '.pkl' suffix so the filename
# matches the `load_learner(path/'multi.pkl')` call in the UI demo cell
# (exporting as plain 'multi' would make that load fail).
multi_learner.export('multi.pkl')

Same model, but with an explicitly assigned loss function

# Same multi-label model, but with an explicitly assigned loss function.
# Use fastai's BCEWithLogitsLossFlat rather than raw nn.BCEWithLogitsLoss:
# the flat wrapper carries the `activation` (sigmoid) and `decodes`
# (thresholding) hooks that Learner.predict relies on. With the raw loss,
# predict returned an empty label list and unscaled logits (see the
# multi_learner2.predict outputs earlier in this notebook).
ls = BCEWithLogitsLossFlat()
multi_learner2 = cnn_learner(dls, resnet18, loss_func=ls, metrics=accuracy_multi)
multi_learner2.fine_tune(4)
epoch train_loss valid_loss accuracy_multi time
0 0.872250 0.800857 0.753205 00:24
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
epoch train_loss valid_loss accuracy_multi time
0 0.430897 0.460585 0.833333 00:25
1 0.379967 0.258608 0.916667 00:25
2 0.311630 0.176938 0.939103 00:25
3 0.259250 0.157732 0.935897 00:25
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
multi_learner2.get_preds(with_decoded=False)
/usr/local/lib/python3.7/dist-packages/PIL/Image.py:960: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  "Palette images with Transparency expressed in bytes should be "
(TensorBase([[ 1.1054e+01, -4.3865e+00, -3.8246e+00],
         [-4.8288e+00, -6.0278e+00,  1.2231e+01],
         [-4.8429e+00,  8.0363e-02,  8.5969e-01],
         [-1.1705e+00, -5.7417e+00,  1.0362e+01],
         [-4.0937e+00, -2.2996e+00,  9.4857e+00],
         [-3.8556e+00, -4.1440e+00,  1.2390e+01],
         [-3.8271e+00,  6.5946e+00, -4.1401e+00],
         [ 1.0715e+01, -5.2831e+00, -6.1691e+00],
         [-1.2443e+00,  4.7203e+00, -5.1741e+00],
         [-3.3200e+00, -2.3427e+00,  7.2592e+00],
         [-3.9163e+00, -2.3793e+00,  9.4112e+00],
         [-3.1000e+00, -2.0570e+00,  6.1245e+00],
         [-2.9966e+00, -2.6585e+00,  1.0127e+01],
         [ 1.2771e+01, -6.1638e+00, -5.5002e+00],
         [-2.3349e+00, -2.4252e+00,  8.9080e+00],
         [ 1.2209e+01, -8.3553e+00, -6.7067e+00],
         [-5.0397e+00, -2.6586e+00,  1.0207e+01],
         [ 1.0072e+01, -2.9396e+00, -6.1747e+00],
         [-4.9629e+00, -5.0143e+00,  1.2457e+01],
         [-2.4362e+00, -3.4457e+00,  1.1917e+01],
         [-3.3012e+00, -4.1240e+00,  9.7303e+00],
         [-2.6708e+00,  6.3316e+00, -5.4875e+00],
         [-4.2726e+00, -4.0000e+00,  8.1809e+00],
         [-4.7021e+00, -6.8715e+00,  1.3693e+01],
         [ 1.1156e+01, -7.3597e+00, -4.5801e+00],
         [-5.1651e+00, -4.5510e+00,  1.7812e+01],
         [-2.6461e+00, -5.3205e+00,  1.2895e+01],
         [ 1.0624e+01, -6.1674e+00, -7.2841e+00],
         [-3.9826e+00, -3.9549e+00,  9.2268e+00],
         [ 1.0253e+01, -4.5484e+00, -6.0170e+00],
         [-2.3421e+00,  5.4049e+00, -3.4479e+00],
         [ 3.7226e+00, -5.4512e+00, -9.5822e-01],
         [-3.8336e+00,  7.4639e+00, -2.5219e+00],
         [-4.4530e+00, -4.9878e+00,  1.5383e+01],
         [ 8.7228e+00, -3.2869e+00, -6.2374e+00],
         [-4.7827e+00,  1.1229e+01, -7.6850e+00],
         [-3.2945e+00, -3.4483e+00,  6.4119e+00],
         [ 9.8750e+00, -5.1142e+00, -8.9492e+00],
         [ 9.5270e+00, -9.3801e+00, -7.7053e+00],
         [ 1.6784e+01, -8.5052e+00, -8.4742e+00],
         [-5.0142e+00, -8.1198e+00,  1.0895e+01],
         [ 1.3255e+00,  1.6833e+00, -2.7004e+00],
         [ 1.7176e+01, -1.0126e+01, -9.0501e+00],
         [-3.8401e+00, -7.1575e+00,  1.6060e+01],
         [-3.3630e+00, -3.8490e+00,  1.2150e+01],
         [ 8.3805e-01,  3.2154e-01, -4.7388e+00],
         [-4.5516e+00,  8.7604e+00, -6.0815e+00],
         [-1.9349e+00, -5.9613e+00,  1.3269e+01],
         [-5.3436e+00, -4.0963e+00,  9.9250e+00],
         [ 1.0194e+01, -3.8770e+00, -5.9039e+00],
         [ 1.1517e+01, -6.8700e+00, -8.5469e+00],
         [-3.3945e+00,  6.4746e+00, -4.0815e+00],
         [ 6.8645e+00, -2.5304e+00, -4.9125e+00],
         [-4.3031e+00,  7.2538e+00, -3.6068e+00],
         [ 7.9635e+00, -6.2923e+00, -4.8670e+00],
         [-2.2701e+00, -5.0116e+00,  8.7348e+00],
         [ 1.2780e+01, -3.8046e+00, -7.6140e+00],
         [ 8.5537e-03,  1.8284e+00, -5.9196e+00],
         [-3.5084e-01,  1.1758e+00, -4.3755e+00],
         [ 7.3451e+00, -4.7630e+00, -5.5742e+00],
         [-4.3683e+00, -6.3360e+00,  1.4351e+01],
         [-6.8580e+00, -4.3077e-01,  6.4575e+00],
         [ 1.5667e+01, -4.7154e+00, -8.2658e+00],
         [ 9.2235e+00, -5.7826e+00, -4.5370e+00],
         [-7.4300e-01,  4.6694e+00, -6.4751e+00],
         [-4.1144e+00, -3.3529e+00,  8.9280e+00],
         [ 1.3519e+01, -5.3617e+00, -6.0206e+00],
         [ 1.1722e+01, -6.0788e+00, -7.7755e+00],
         [-5.9252e+00,  1.0020e+01, -6.6066e+00],
         [-5.0563e+00, -7.5295e+00,  1.5377e+01],
         [ 3.6416e-01,  7.0579e+00, -8.8655e+00],
         [-4.4168e+00, -5.0467e+00,  8.1200e+00],
         [ 1.2095e+01, -4.9284e+00, -5.3975e+00],
         [-3.5256e+00, -2.0509e+00,  7.1033e+00],
         [ 3.7408e+00, -1.7004e-01, -7.8555e+00],
         [ 2.8994e+00,  1.7881e+00, -4.6694e+00],
         [ 8.7549e+00, -5.0893e+00, -6.9562e+00],
         [ 1.8290e+01, -6.4178e+00, -9.8740e+00],
         [ 1.1128e+01, -4.8088e+00, -9.0944e+00],
         [ 6.3693e+00, -3.6620e+00, -6.3378e+00],
         [-2.6944e+00,  4.9052e+00, -6.6235e+00],
         [-5.1428e+00, -5.0365e+00,  1.1248e+01],
         [ 1.4320e+00,  7.9101e-01, -5.2702e+00],
         [-1.6046e+00,  4.2638e+00, -4.9068e+00],
         [-4.9069e+00,  7.9584e+00, -5.2792e+00],
         [-3.2035e+00, -1.7164e+00,  1.0915e+01],
         [ 4.9181e+00, -2.4524e+00, -3.6611e+00],
         [-5.2584e+00, -4.5208e+00,  1.1430e+01],
         [-7.7838e-01,  4.1009e+00, -5.6056e+00],
         [ 1.1177e+01, -5.5141e+00, -5.9799e+00],
         [-3.1584e+00, -4.7642e+00,  9.7285e+00],
         [-4.4772e+00, -2.2261e+00,  9.7936e+00],
         [ 2.0223e+00,  2.5525e+00, -6.6029e+00],
         [-4.8925e+00, -3.6348e+00,  1.1941e+01],
         [ 2.8736e+00,  8.6336e-01, -5.2133e+00],
         [-3.2290e+00,  5.8168e+00, -4.3457e+00],
         [ 8.9244e-01,  5.9619e+00, -8.6835e+00],
         [ 2.8697e-02,  3.5361e+00, -5.8552e+00],
         [ 4.0667e+00,  2.7531e+00, -6.2681e+00],
         [-3.0422e+00, -4.2448e+00,  1.2542e+01],
         [-4.6983e+00, -1.7849e+00,  1.2067e+01],
         [ 4.2613e+00,  5.1500e-01, -7.1514e+00],
         [-3.4810e+00, -6.4663e+00,  1.6082e+01],
         [ 2.6145e-01, -3.9896e+00,  6.7493e+00]]),
 TensorMultiCategory([[1., 0., 0.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 1., 0.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 1., 0.],
         [0., 1., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [1., 0., 0.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 0., 1.],
         [0., 1., 0.],
         [0., 1., 0.],
         [0., 1., 0.],
         [0., 1., 0.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.],
         [1., 0., 0.],
         [0., 0., 1.],
         [0., 0., 1.]]))